from __future__ import annotations

from copy import deepcopy
from datetime import date  # noqa: TCH003
from typing import TYPE_CHECKING, Callable, Literal
from uuid import uuid4

from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
    trim_messages,
)
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode

from tablegpt.agent.output_parser import MarkdownOutputParser
from tablegpt.retriever import format_columns
from tablegpt.safety import create_hazard_classifier
from tablegpt.tools import (
    IPythonTool,
    markdown_console_template,
    process_content,
)
from tablegpt.utils import filter_contents

if TYPE_CHECKING:
    from pathlib import Path

    from langchain_core.agents import AgentAction, AgentFinish
    from langchain_core.language_models import BaseLanguageModel
    from langchain_core.retrievers import BaseRetriever
    from langchain_core.runnables import Runnable
    from pybox.base import BasePyBoxManager


INSTRUCTION = """You are TableGPT2, an expert Python data analyst developed by Zhejiang University. Your job is to help user analyze datasets by writing Python code. Each markdown codeblock you write will be executed in an IPython environment, and you will receive the execution output. You should provide results analysis based on the execution output.
For politically sensitive questions, security and privacy issues, or other non-data analyze questions, you will refuse to answer.

Remember:
- Comprehend the user's requirements carefully & to the letter.
- If additional information is needed, feel free to ask the user.
- Give a brief description for what you plan to do & write Python code.
- You can use `read_df(uri: str) -> pd.DataFrame` function to read different file formats into DataFrame.
- When creating charts, prefer using `seaborn`.
- DO NOT include images using markdown syntax (![]()) in your response under ANY circumstances.
- If error occurred, try to fix it.
- Response in the same language as the user.
- Today is {date}"""

PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", INSTRUCTION),
        ("placeholder", "{messages}"),
    ]
)


def get_data_analyzer_agent(llm: BaseLanguageModel) -> Runnable:
    return PROMPT | llm | MarkdownOutputParser(language_actions={"python": "python", "py": "python"})


class AgentState(MessagesState):
    # Current Date
    date: date

    # This is a bit of a hack to pass parent id to the agent state
    # But it act as the group id of all messages generated by the agent
    parent_id: str | None


def get_messages_truncation_config(
    blm: BaseLanguageModel | None,
    trim_message_method: Literal["default", "token_count"] = "default",
) -> tuple[
    Callable[[list[BaseMessage]], int] | Callable[[BaseMessage], int] | BaseLanguageModel,
    int,
]:
    match trim_message_method:
        case "default":
            return len, 20

        case "token_count":
            return blm, int((blm.metadata or {}).get("max_history_tokens", 4096))
        case _:
            e_msg = f"The expected value should be one of ['default', 'token_count'], but you provided {trim_message_method}."
            raise ValueError(e_msg)


def create_hazardous_hint_message(message: BaseMessage) -> BaseMessage | None:
    """Generate a hint message based on the message type and hazard category."""
    if isinstance(message, HumanMessage) and (category := message.additional_kwargs.get("hazard")):
        # TODO: "敏感话题"?
        content = f"用户问题可能涉及与 `{category}` 相关的敏感话题，请谨慎回答。"  # noqa: RUF001
        return SystemMessage(
            id=str(uuid4()),
            content=content,
        )
    return None


def create_data_analyze_workflow(
    llm: BaseLanguageModel,
    pybox_manager: BasePyBoxManager,
    *,
    workdir: Path | None = None,
    session_id: str | None = None,
    error_trace_cleanup: bool = False,
    vlm: BaseLanguageModel | None = None,
    safety_llm: Runnable | None = None,
    dataset_retriever: BaseRetriever | None = None,
    verbose: bool = False,
    trim_message_method: Literal["default", "token_count"] = "default",
) -> Runnable:
    """Creates a data analysis workflow for processing user input and datasets.

    This function constructs a state graph that orchestrates various tasks involved
    in analyzing data, including safety checks, column retrieval from datasets,
    and invoking the appropriate agent (either the standard or vision language model agent).
    The workflow is designed to handle multiple types of messages and responses.

    Args:
        llm (BaseLanguageModel): The primary language model for processing user input.
        pybox_manager (BasePyBoxManager):  A python code sandbox delegator, used to execute the data analysis code generated by llm.
        workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None.
        session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None.
        error_trace_cleanup (bool, optional): Flag to indicate if error traces should be cleaned up. Defaults to False.
        vlm (BaseLanguageModel | None, optional): Optional vision language model for processing images. Defaults to None.
        safety_llm (Runnable | None, optional): Model used for safety classification of inputs. Defaults to None.
        dataset_retriever (BaseRetriever | None, optional): Component to retrieve dataset columns based on user input. Defaults to None.
        verbose (bool, optional): Flag to enable detailed logging. Defaults to False.
        trim_message_method (Literal["default", "token_count"], optional): Determines the method used to trim the message. Defaults to "default".
            - "default": Applies the default trimming method (Truncate using the length of messages, default max length is 20).
            - "token_count": Use token count to truncate messages.
                Ensure the `BaseLanguageModel` has the `get_num_tokens_from_messages` method.
                And set `max_history_tokens` in `BaseLanguageModel.metadata`, e.g., {"max_history_tokens": 4096} (default 4096).
                You can specify the value using: `max_model_len (max tokens the service supports) - max_new_tokens (tokens needed for the request)`.

    Returns:
        Runnable: A runnable object representing the data analysis workflow.
    """
    agent = get_data_analyzer_agent(llm)

    vlm_agent = None
    if vlm is not None:
        vlm_agent = get_data_analyzer_agent(vlm)

    hazard_classifier = None
    if safety_llm is not None:
        hazard_classifier = create_hazard_classifier(safety_llm)

    tools = [
        IPythonTool(
            pybox_manager=pybox_manager,
            cwd=workdir,
            session_id=session_id,
            error_trace_cleanup=error_trace_cleanup,
        )
    ]
    tool_executor = ToolNode(tools)

    async def input_guard(
        state: AgentState,
    ) -> dict[str, list[BaseMessage]]:
        if hazard_classifier is not None:
            last_message = state["messages"][-1]
            flag, category = await hazard_classifier.ainvoke(input={"input": last_message.content})
            if flag == "unsafe" and category is not None:
                last_message.additional_kwargs["hazard"] = category
                return {"messages": [last_message]}
        return {"messages": []}

    async def retrieve_columns(state: AgentState) -> dict:
        if dataset_retriever is None:
            return {"messages": []}

        last_message = state["messages"][-1]
        docs = await dataset_retriever.ainvoke(
            input=last_message.content,
        )
        formatted = format_columns(docs)
        return {
            "messages": [
                SystemMessage(
                    id=str(uuid4()),
                    content=formatted,
                    additional_kwargs={"parent_id": state["parent_id"]},
                )
            ]
        }

    async def agent_node(state: AgentState) -> dict:
        # Truncate messages based on message count.

        token_counter, max_tokens = get_messages_truncation_config(llm, trim_message_method)

        messages = state["messages"][:]
        last_message = messages[-1]
        if safety_llm is not None and last_message.additional_kwargs.get("hazard") is not None:
            hazard_message = create_hazardous_hint_message(last_message)
            if hazard_message is not None:
                messages.append(hazard_message)

        windowed_messages = trim_messages(
            messages,
            token_counter=token_counter,
            max_tokens=max_tokens,
            start_on="human",  # This means that the first message should be from the user after trimming.
            # The system message is not in `messages`, so we don't need to specify `include_system`
        )
        # Keep only 'text' and 'table' content
        filtered_messages = filter_contents(windowed_messages, keep={"text", "table"})

        # Extract filename from attachments to content
        temp_messages = deepcopy(filtered_messages)
        for message in temp_messages:
            if attachments := message.additional_kwargs.get("attachments"):
                # TODO: We only support one attachment for now.
                message.content = f"文件名称: {attachments[0]['filename']}"

        agent_outcome: AgentAction | AgentFinish = await agent.ainvoke(
            {
                "messages": temp_messages,
                "date": state["date"],
            }
        )

        messages = []
        for message in agent_outcome.messages:
            message.additional_kwargs["parent_id"] = state["parent_id"]
            messages.append(message)
        return {"messages": messages}

    async def vlm_agent_node(state: AgentState) -> dict:
        # Truncate messages based on message count.
        token_counter, max_tokens = get_messages_truncation_config(vlm, trim_message_method)

        windowed_messages: list[BaseMessage] = trim_messages(
            state["messages"],
            token_counter=token_counter,
            max_tokens=max_tokens,
            start_on="human",  # This means that the first message should be from the user after trimming.
            # The system message is not in `messages`, so we don't need to specify `include_system`
        )
        # NOTE: This is hacky, but VLMs have limits on the number of images they can process.
        # First we keep only 'text' part for all windowed messages except the last one.
        filtered_messages = filter_contents(windowed_messages[:-1], keep={"text"})
        # Then we add the image content of the last message back, keep it under `max_support_images`.
        if isinstance(windowed_messages[-1].content, str):
            last_message = windowed_messages[-1]
        else:
            max_support_images = int((vlm.metadata or {}).get("max_support_images", 5))
            last_message: BaseMessage = deepcopy(windowed_messages[-1])
            last_message.content = []
            added = 0
            for part in reversed(windowed_messages[-1].content):
                if isinstance(part, str):
                    last_message.content.insert(0, part)
                    continue
                if part.get("type") == "text":
                    last_message.content.insert(0, part)
                    continue
                if part.get("type") == "image_url" and added < max_support_images:
                    last_message.content.insert(0, part)
                    added += 1
        filtered_messages.append(last_message)

        # Extract filename from attachments to content
        temp_messages = deepcopy(filtered_messages)
        for message in temp_messages:
            if attachments := message.additional_kwargs.get("attachments"):
                # TODO: We only support one attachment for now.
                message.content = f"文件名称: {attachments[0]['filename']}"

        agent_outcome: AgentAction | AgentFinish = await vlm_agent.ainvoke(
            {
                "messages": temp_messages,
                "date": state["date"],
            }
        )
        messages = []
        for message in agent_outcome.messages:
            message.additional_kwargs["parent_id"] = state["parent_id"]
            messages.append(message)
        return {"messages": messages}

    async def tool_node(state: AgentState) -> dict:
        messages: list[ToolMessage] = await tool_executor.ainvoke(state["messages"])
        for message in messages:
            message.additional_kwargs = message.additional_kwargs | {
                "parent_id": state["parent_id"],
            }
            # TODO: we assume our tool is only IPythonTool, so we can hardcode the format here.
            message.content = process_content(message.content)
            for part in message.content:
                if isinstance(part, dict) and part.get("type") == "text":
                    part["text"] = markdown_console_template.format(res=part["text"])
        return {"messages": messages}

    # I cannot use `END` as the literal hint, as:
    #  > Type arguments for "Literal" must be None, a literal value (int, bool, str, or bytes), or an enum value.
    # As `END` is just an intern string of "__end__" (See `langgraph.constants`), So I use "__end__" here.
    def should_continue(state: AgentState) -> Literal["tool_node", "__end__"]:
        # Must have at least one message when entering this router
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            return "tool_node"
        return END

    def agent_selector(state: AgentState) -> Literal["agent_node", "vlm_agent_node"]:
        if vlm_agent is None:
            return "agent_node"

        # No messages yet. We should start with the default agent
        if len(state["messages"]) < 1:
            return "agent_node"

        # If the latest message contains "image/xxx" output,
        # the workflow graph shoud route to "vlm_agent"
        last_message = state["messages"][-1]
        for part in last_message.content:
            if isinstance(part, dict) and part.get("type") == "image_url":
                return "vlm_agent_node"
        return "agent_node"

    workflow = StateGraph(AgentState)

    workflow.add_node(input_guard)
    workflow.add_node(retrieve_columns)
    workflow.add_node(agent_node)
    workflow.add_node(vlm_agent_node)
    workflow.add_node(tool_node)

    workflow.add_edge(START, "input_guard")
    workflow.add_edge(START, "retrieve_columns")
    workflow.add_edge(["input_guard", "retrieve_columns"], "agent_node")

    workflow.add_conditional_edges("tool_node", agent_selector)
    workflow.add_conditional_edges("agent_node", should_continue)
    workflow.add_conditional_edges("vlm_agent_node", should_continue)
    return workflow.compile(debug=verbose)
